{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "语言生成中，一次生成一个单词，然后将该单词添加到输入中，预测下一个单词；\n",
    "- 每一个时间步，模型给出的都是基于历史生成结果的条件概率。\n",
    "  \n",
    "设输出文本词典 $\\mathcal{Y}$（包含特殊符号`<eos>`）的大小为 $\\left|\\mathcal{Y}\\right|$，输出序列的最大长度为 $T′$。所有可能的输出序列一共有 $\\mathcal{O}(\\left|\\mathcal{Y}\\right|^{T'})$ 种。这些输出序列中所有特殊符号`<eos>`后面的子序列将被舍弃。\n",
    "\n",
    "如何从所有序列中选择需要的序列，获得完整的句子？   \n",
    "需要一个称为 **解码** 的额外动作来融合模型多个时间步的输出，而且使得最终得到的序列概率最大。\n",
    ">- 自回归语言生成基于一个假设，即生成的单词序列的概率分布可以分解为每一步生成单词的概率的连乘\n",
    "$$P(w_{1:T}|w_0) = \\prod_{t=1}^TP(w_{t}|w_{1:t-1},w_0)$$\n",
    "其中 $w_0$ 为起始符号，生成`EOS`时停止\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 贪婪搜索(Greedy Search):\n",
    "\n",
    "- 在每个时间步，选择概率最高的单词作为结果\n",
    "$$y_{t'} = \\operatorname*{argmax}_{y \\in \\mathcal{Y}} P(y \\mid y_1, \\ldots, y_{t'-1}, \\mathbf{c})$$ \n",
    "如下图所示，得到的生成输出序列 ` ABC<eos> `。该输出序列的条件概率是      $0.5\\times0.4\\times0.4\\times0.6 = 0.048$\n",
    "<img src=\"../images/greedy-search.svg\" width=\"40%\">  \n",
    "   \n",
    "      \n",
    "         \n",
    "- 在时间步 2 选择概率第 2 大的词 `C`，由于时间步3所基于的时间步1和2的输出子序列由上图中的 `“A”“B”`变为的`“A”“C”`，则时间步 3 生成各个词的条件概率发生了变化。\n",
    "<img src=\"../images/greedy-search2.svg\" width=\"40%\">  \n",
    "此时的输出序列 `“A”“C”“B”“<eos>”`的条件概率是 $0.5\\times0.3\\times0.6\\times0.6=0.054$，大于贪婪搜索得到的输出序列的条件概率。\n",
    ">因此，**贪婪搜索得到的输出序列并非最优输出序列。**\n",
    "   \n",
    "     \n",
    "- 贪心搜索产生的结果很快会出现重复，其主要确定是，错过了隐藏在低概率单词后面的高概率单词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-12T10:38:28.719260Z",
     "start_time": "2020-05-12T10:38:27.594535Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelWithLMHead"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-12T10:38:33.305947Z",
     "start_time": "2020-05-12T10:38:29.650898Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated :\n",
      " The dog was found around 7:20 a.m. on the ground. No foul play was found, and the dog was transported to an area hospital that is investigating the incident.\n",
      "\n",
      "\n",
      "\"\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('../../H/models/huggingface/gpt2')\n",
    "model = AutoModelWithLMHead.from_pretrained('../../H/models/huggingface/gpt2')\n",
    "input_context = 'The dog'\n",
    "input_ids = tokenizer.encode(input_context, return_tensors='pt')\n",
    "outputs = model.generate(\n",
    "    input_ids=input_ids,\n",
    "    max_length=40,\n",
    ")\n",
    "print('Generated :\\n {}'.format(\n",
    "    tokenizer.decode(outputs[0], skip_special_tokens=True)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 穷举搜索(exhausive search)\n",
    "\n",
    "穷举所有可能的输出序列，得到条件概率最大的序列；计算开销  $\\mathcal{O}(\\left|\\mathcal{Y}\\right|^{T'})$ 很容易过大。例如，当$|\\mathcal{Y}|=10000$且$T'=10$时，我们将评估$10000^{10} = 10^{40}$个序列"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 集束搜索(Beam Search)\n",
    "- 在每个时间步，保留最可能的几个(num_beams)单词，最终选择总体概率最高的假设；相当于形成一个多叉树，最终选择数中概率最高的那条路径；\n",
    "<img src=\"../images/beam_search.svg\" width=\"80%\">  \n",
    "$$\\frac{1}{L^\\alpha} \\log P(y_1, \\ldots, y_{L}) = \\frac{1}{L^\\alpha} \\sum_{t'=1}^L \\log P(y_{t'} \\mid y_1, \\ldots, y_{t'-1}, \\boldsymbol{c}),$$\n",
    "\n",
    "其中L为最终候选序列长度，$\\alpha$一般可选为0.75。分母上的$L^\\alpha$是为了惩罚较长序列在以上分数中较多的对数相加项。束搜索的计算开销为$\\mathcal{O}(k\\left|\\mathcal{Y}\\right|T')$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BeamHypotheses:\n",
    "    def __init__(self, num_beams, max_length, length_penalty):\n",
    "        self.max_length = max_length - 1  # ignoring bos_token\n",
    "        self.num_beams = num_beams\n",
    "        self.beams = []\n",
    "        self.worst_score = 1e9\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.beams)\n",
    "\n",
    "    def add(self, hyp, sum_logprobs):\n",
    "        score = sum_logprobs / len(hyp)**self.length_penalty\n",
    "        if len(self) < self.num_beams or score > self.worst_score:\n",
    "            # 可更新的情况：数量未饱和或超过最差得分\n",
    "            self.beams.append((score, hyp))\n",
    "            if len(self) > self.num_beams:\n",
    "                # 数量饱和需要删掉一个最差的\n",
    "                sorted_scores = sorted([\n",
    "                    (s, idx) for idx, (s, _) in enumerate(self.beams)\n",
    "                ])\n",
    "                del self.beams[sorted_scores[0][1]]\n",
    "                self.worst_score = sorted_scores[1][0]\n",
    "            else:\n",
    "                self.worst_score = min(score, self.worst_score)\n",
    "\n",
    "    def is_done(self, best_sum_logprobs, cur_len=None):\n",
    "        \"\"\"\n",
    "        相关样本是否已经完成生成。\n",
    "        best_sum_logprobs是新的候选序列中的最高得分。\n",
    "        \"\"\"\n",
    "\n",
    "        if len(self) < self.num_beams:\n",
    "            return False\n",
    "        else:\n",
    "            if cur_len is None:\n",
    "                cur_len = self.max_length\n",
    "            cur_score = best_sum_logprobs / cur_len**self.length_penalty\n",
    "            # 是否最高分比当前保存的最低分还差\n",
    "            ret = self.worst_score >= cur_score\n",
    "            return ret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def beam_search_generate(context,\n",
    "                         batch_size=3,\n",
    "                         max_length=20,\n",
    "                         min_length=2,\n",
    "                         num_beams=2,\n",
    "                         bos_token_id=101,\n",
    "                         pad_token_id=0,\n",
    "                         eos_token_id=102):\n",
    "    \"\"\"\n",
    "    context:编码器编码获得的向量\n",
    "    batch_size:每批数据中包含的样本量\n",
    "    bos_token_id:句子开头 token id\n",
    "    pad_token_id:填充的 token id\n",
    "    eos_token_id:句子结束标记的 token id\n",
    "    \"\"\"\n",
    "    # 建立beam容器，每个样本一个\n",
    "    generated_hyps = [\n",
    "        BeamHypotheses(num_beams,\n",
    "                       max_length,\n",
    "                       length_penalty,\n",
    "                       early_stopping=early_stopping)\n",
    "        for _ in range(batch_size)\n",
    "    ]\n",
    "\n",
    "    # 每个beam容器的得分，共batch_size*num_beams个\n",
    "    beam_scores = torch.zeros((batch_size, num_beams),\n",
    "                              dtype=torch.float,\n",
    "                              device=encoder_input_ids.device)\n",
    "    beam_scores = beam_scores.view(-1)\n",
    "\n",
    "    # 每个样本是否完成生成，共batch_size个\n",
    "    done = [False for _ in range(batch_size)]\n",
    "\n",
    "    # 为了并行计算，一次生成batch_size*num_beams个序列\n",
    "    # 第一步自动填入bos_token\n",
    "    input_ids = torch.full(\n",
    "        (batch_size * num_beams, 1),\n",
    "        bos_token_id,\n",
    "        dtype=torch.long,\n",
    "        device=next(self.parameters()).device,\n",
    "    )\n",
    "\n",
    "    # 当前长度设为1\n",
    "    cur_len = 1\n",
    "\n",
    "    while cur_len < max_length:\n",
    "        # 将编码器得到的上下文向量和当前结果输入解码器，即图中1\n",
    "        output = decoder.decode_next_step(context, input_ids)\n",
    "        # 输出矩阵维度为：(batch*num_beams)*cur_len*vocab_size\n",
    "\n",
    "        # 取出最后一个时间步的各token概率，即当前条件概率\n",
    "        # (batch*num_beams)*vocab_size\n",
    "        scores = next_token_logits = output[:, -1, :]\n",
    "\n",
    "        ###########################\n",
    "        # 这里可以做一大堆操作减少重复 #\n",
    "        ###########################\n",
    "\n",
    "        # 计算序列条件概率的，因为取了log，所以直接相加即可。得到图中2矩阵\n",
    "        # (batch_size * num_beams, vocab_size)\n",
    "        next_scores = scores + beam_scores[:, None].expand_as(scores)\n",
    "\n",
    "        # 为了提速，将结果重排成图中3的形状\n",
    "        next_scores = next_scores.view(\n",
    "            batch_size,\n",
    "            num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)\n",
    "\n",
    "        # 取出分数最高的token（图中黑点）和其对应得分\n",
    "        # sorted=True，保证返回序列是有序的\n",
    "        next_scores, next_tokens = torch.topk(next_scores,\n",
    "                                              2 * num_beams,\n",
    "                                              dim=1,\n",
    "                                              largest=True,\n",
    "                                              sorted=True)\n",
    "\n",
    "        # 下一个时间步整个batch的beam列表\n",
    "        # 列表中的每一个元素都是三元组\n",
    "        # (分数, token_id, beam_id)\n",
    "        next_batch_beam = []\n",
    "\n",
    "        # 对每一个样本进行扩展\n",
    "        for batch_idx in range(batch_size):\n",
    "\n",
    "            # 检查样本是否已经生成结束\n",
    "            if done[batch_idx]:\n",
    "                # 对于已经结束的句子，待添加的是pad token\n",
    "                next_batch_beam.extend([(0, pad_token_id, 0)] *\n",
    "                                       num_beams)  # pad the batch\n",
    "                continue\n",
    "\n",
    "            # 当前样本下一个时间步的beam列表\n",
    "            next_sent_beam = []\n",
    "\n",
    "            # 对于还未结束的样本需要找到分数最高的num_beams个扩展\n",
    "            # 注意，next_scores和next_tokens是对应的\n",
    "            # 而且已经按照next_scores排好顺序\n",
    "            for beam_token_rank, (beam_token_id,\n",
    "                                  beam_token_score) in enumerate(\n",
    "                                      zip(next_tokens[batch_idx],\n",
    "                                          next_scores[batch_idx])):\n",
    "                # get beam and word IDs\n",
    "                # 这两行可参考图中3进行理解\n",
    "                beam_id = beam_token_id // vocab_size\n",
    "                token_id = beam_token_id % vocab_size\n",
    "\n",
    "                effective_beam_id = batch_idx * num_beams + beam_id\n",
    "\n",
    "                # 如果出现了EOS token说明已经生成了完整句子\n",
    "                if (eos_token_id is\n",
    "                        not None) and (token_id.item() == eos_token_id):\n",
    "                    # if beam_token does not belong to top num_beams tokens, it should not be added\n",
    "                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams\n",
    "                    if is_beam_token_worse_than_top_num_beams:\n",
    "                        continue\n",
    "                    # 往容器中添加这个序列\n",
    "                    generated_hyps[batch_idx].add(\n",
    "                        input_ids[effective_beam_id].clone(),\n",
    "                        beam_token_score.item(),\n",
    "                    )\n",
    "                else:\n",
    "                    # add next predicted word if it is not eos_token\n",
    "                    next_sent_beam.append(\n",
    "                        (beam_token_score, token_id, effective_beam_id))\n",
    "\n",
    "                # 扩展num_beams个就够了\n",
    "                if len(next_sent_beam) == num_beams:\n",
    "                    break\n",
    "\n",
    "            # 检查这个样本是否已经生成完了，有两种情况\n",
    "            # 1. 已经记录过该样本结束\n",
    "            # 2. 新的结果没有使结果改善\n",
    "            done[batch_idx] = done[batch_idx] or generated_hyps[\n",
    "                batch_idx].is_done(next_scores[batch_idx].max().item(),\n",
    "                                   cur_len=cur_len)\n",
    "\n",
    "            # 把当前样本的结果添加到batch结果的后面\n",
    "            next_batch_beam.extend(next_sent_beam)\n",
    "\n",
    "        # 如果全部样本都已经生成结束便可以直接退出了\n",
    "        if all(done):\n",
    "            break\n",
    "\n",
    "        # 把三元组列表再还原成三个独立列表\n",
    "        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])\n",
    "        beam_tokens = input_ids.new([x[1] for x in next_batch_beam])\n",
    "        beam_idx = input_ids.new([x[2] for x in next_batch_beam])\n",
    "\n",
    "        # 准备下一时刻的解码器输入\n",
    "        # 取出实际被扩展的beam\n",
    "        input_ids = input_ids[beam_idx, :]\n",
    "        # 在这些beam后面接上新生成的token\n",
    "        input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)\n",
    "\n",
    "        # 更新当前长度\n",
    "        cur_len = cur_len + 1\n",
    "        # end of length while\n",
    "\n",
    "    # 将未结束的生成结果结束，并置入容器中\n",
    "    for batch_idx in range(batch_size):\n",
    "        # 已经结束的样本不需处理\n",
    "        if done[batch_idx]:\n",
    "            continue\n",
    "\n",
    "        # 把结果加入到generated_hyps容器\n",
    "        for beam_id in range(num_beams):\n",
    "            effective_beam_id = batch_idx * num_beams + beam_id\n",
    "            final_score = beam_scores[effective_beam_id].item()\n",
    "            final_tokens = input_ids[effective_beam_id]\n",
    "            generated_hyps[batch_idx].add(final_tokens, final_score)\n",
    "\n",
    "    # select the best hypotheses，最终输出\n",
    "    # 每个样本返回几个句子\n",
    "    output_num_return_sequences_per_batch = 1\n",
    "    # 记录每个返回句子的长度，用于后面pad\n",
    "    sent_lengths = input_ids.new(output_batch_size)\n",
    "    best = []\n",
    "\n",
    "    # 对每个样本取出最好的output_num_return_sequences_per_batch个句子\n",
    "    for i, hypotheses in enumerate(generated_hyps):\n",
    "        sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])\n",
    "        for j in range(output_num_return_sequences_per_batch):\n",
    "            effective_batch_idx = output_num_return_sequences_per_batch * i + j\n",
    "            best_hyp = sorted_hyps.pop()[1]\n",
    "            sent_lengths[effective_batch_idx] = len(best_hyp)\n",
    "            best.append(best_hyp)\n",
    "\n",
    "    # 如果长短不一则pad句子，使得最后返回结果的长度一样\n",
    "    if sent_lengths.min().item() != sent_lengths.max().item():\n",
    "        sent_max_len = min(sent_lengths.max().item() + 1, max_length)\n",
    "        # 先把输出矩阵填满PAD token\n",
    "        decoded = input_ids.new(output_batch_size,\n",
    "                                sent_max_len).fill_(pad_token_id)\n",
    "\n",
    "        # 填入真正的内容\n",
    "        for i, hypo in enumerate(best):\n",
    "            decoded[i, :sent_lengths[i]] = hypo\n",
    "            # 填上eos token\n",
    "            if sent_lengths[i] < max_length:\n",
    "                decoded[i, sent_lengths[i]] = eos_token_id\n",
    "    else:\n",
    "        # 所有生成序列都还没结束，直接堆叠即可\n",
    "        decoded = torch.stack(best).type(torch.long).to(\n",
    "            next(self.parameters()).device)\n",
    "\n",
    "    # 返回的结果包含BOS token\n",
    "    return decoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-12T10:42:00.781193Z",
     "start_time": "2020-05-12T10:41:59.148591Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated :\n",
      " The dog was found around 7:20 a.m. on the ground. No foul play was found, and the dog was transported to an area hospital that is investigating the incident.\n",
      "\n",
      "\n",
      "\"\n"
     ]
    }
   ],
   "source": [
    "beam_output = model.generate(input_ids, max_length=50, num_beams=5)\n",
    "print('Generated :\\n {}'.format(\n",
    "    tokenizer.decode(outputs[0], skip_special_tokens=True)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beam_output = model.generate(input_ids,\n",
    "                             max_length=50,\n",
    "                             num_beams=5,\n",
    "                             no_repeat_ngram_size=2,\n",
    "                             early_stopping=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 集束搜索可以很好地完成任务——在机器翻译或摘要中，期望生成的长度或多或少是可预测的。\n",
    "- 但在开放式的生成，期望输出的长度可能会有很大的变化，例如对话和故事生成，集束搜索可能不是最好选择的问题\n",
    "- Beam Search虽然比贪心强了不少，但还是会生成出空洞、重复、前后矛盾的文本。《The Curious Case of Neural Text Degeneration》[1]论文认为这种问题是由于这种试图最大化序列条件概率的解码策略从根上就有问题"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 采样(sampling)\n",
    "- 生成单词时不再时选择概率最高的单词，而是随机采样；此时语言生成不再是确定性的，容易产生前后不一致的问题。但在在开放闲聊领域，生成文本的长度都比较短，这种问题就被自然的淡化了。\n",
    "- 通过改变 `softmax` 输出时的超参 `temperature`可以控制概率分布的形貌，当$T$大的时候，概率分布趋向平均，随机性增大；当小的时候，概率密度趋向于集中，即强者愈强，随机性降低。降低$T$，提高高概率单词被采样的可能性，降低生成的随机性\n",
    "$$p_i = \\frac{exp(y_{i}/T)}{\\sum exp(y_{i}/T)}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beam_output = model.generate(input_ids,\n",
    "                             max_length=50,\n",
    "                             do_sample=True,\n",
    "                             top_k=0,\n",
    "                             temperature=0.7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Top-K 抽样\n",
    "- 在采样前将输出的概率分布截断，取出概率最大的`k`个词构成一个集合，然后将这个子集词的概率再归一化，最后从新的概率分布中采样词汇。\n",
    "- 难点在于超参 `k` 的选择。因为这个概率分布变化比较大，有时候可能很均匀(flat)，有的时候比较集中(peaked)。对于集中的情况还好说，当分布均匀时，一个较小的k容易丢掉很多优质候选词。但如果k定的太大，这个方法又会退化回普通采样。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beam_output = model.generate(input_ids,\n",
    "                             max_length=50,\n",
    "                             do_sample=True,\n",
    "                             top_k=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Top-p (nucleus) sampling\n",
    "- 改在概率从高到底选择输出，构造一个最小候选集 $V$，使得 $$\\sum_{x\\in V}P(x)>p$$\n",
    "- 然后重新归一化集合内词的概率，并从中采样\n",
    "- 也可以将 Top-k 和 Top-p 两者结合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beam_output = model.generate(input_ids,\n",
    "                             max_length=50,\n",
    "                             do_sample=True,\n",
    "                             top_k=0,\n",
    "                             top_p=0.92)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 惩罚重复\n",
    "- 仍然会出现重复，加入 `n_grams` 惩罚，保证 `n_grams` 不会出现两次。\n",
    "- 必须谨慎使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def top_k_top_p_filtering(logits,\n",
    "                          top_k=0,\n",
    "                          top_p=1.0,\n",
    "                          filter_value=-float(\"Inf\"),\n",
    "                          min_tokens_to_keep=1):\n",
    "    \"\"\" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering\n",
    "        Args:\n",
    "            logits: logits distribution shape (batch size, vocabulary size)\n",
    "            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).\n",
    "            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).\n",
    "                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)\n",
    "            Make sure we keep at least min_tokens_to_keep per batch example in the output\n",
    "        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317\n",
    "    \"\"\"\n",
    "    if top_k > 0:\n",
    "        top_k = min(max(top_k, min_tokens_to_keep),\n",
    "                    logits.size(-1))  # Safety check\n",
    "        # Remove all tokens with a probability less than the last token of the top-k\n",
    "        indices_to_remove = logits < torch.topk(logits,\n",
    "                                                top_k)[0][..., -1, None]\n",
    "        logits[indices_to_remove] = filter_value\n",
    "\n",
    "    if top_p < 1.0:\n",
    "        sorted_logits, sorted_indices = torch.sort(logits, descending=True)\n",
    "        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),\n",
    "                                        dim=-1)\n",
    "\n",
    "        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)\n",
    "        sorted_indices_to_remove = cumulative_probs > top_p\n",
    "        if min_tokens_to_keep > 1:\n",
    "            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)\n",
    "            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0\n",
    "        # Shift the indices to the right to keep also the first token above the threshold\n",
    "        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[\n",
    "            ..., :-1].clone()\n",
    "        sorted_indices_to_remove[..., 0] = 0\n",
    "\n",
    "        # scatter sorted tensors to original indexing\n",
    "        indices_to_remove = sorted_indices_to_remove.scatter(\n",
    "            1, sorted_indices, sorted_indices_to_remove)\n",
    "        logits[indices_to_remove] = filter_value\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams,\n",
    "                                prev_output_tokens, repetition_penalty):\n",
    "    \"\"\"repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). \"\"\"\n",
    "    for i in range(batch_size * num_beams):\n",
    "        for previous_token in set(prev_output_tokens[i].tolist()):\n",
    "            # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability\n",
    "            if lprobs[i, previous_token] < 0:\n",
    "                lprobs[i, previous_token] *= repetition_penalty\n",
    "            else:\n",
    "                lprobs[i, previous_token] /= repetition_penalty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size,\n",
    "                       cur_len):\n",
    "    # Copied from fairseq for no_repeat_ngram in beam_search\"\"\"\n",
    "    if cur_len + 1 < no_repeat_ngram_size:\n",
    "        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet\n",
    "        return [[] for _ in range(num_hypos)]\n",
    "    generated_ngrams = [{} for _ in range(num_hypos)]\n",
    "    for idx in range(num_hypos):\n",
    "        gen_tokens = prev_input_ids[idx].numpy().tolist()\n",
    "        generated_ngram = generated_ngrams[idx]\n",
    "        # 就是这巧妙的一句\n",
    "        for ngram in zip(\n",
    "                *[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):\n",
    "            prev_ngram_tuple = tuple(ngram[:-1])\n",
    "            generated_ngram[prev_ngram_tuple] = generated_ngram.get(\n",
    "                prev_ngram_tuple, []) + [ngram[-1]]\n",
    "\n",
    "    def _get_generated_ngrams(hypo_idx):\n",
    "        # Before decoding the next token, prevent decoding of ngrams that have already appeared\n",
    "        start_idx = cur_len + 1 - no_repeat_ngram_size\n",
    "        ngram_idx = tuple(\n",
    "            prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())\n",
    "        return generated_ngrams[hypo_idx].get(ngram_idx, [])\n",
    "\n",
    "    banned_tokens = [\n",
    "        _get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)\n",
    "    ]\n",
    "    return banned_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if do_sample:\n",
    "    # 这是今天的采样方式\n",
    "    _scores = scores + beam_scores[:, None].expand_as(\n",
    "        scores)  # (batch_size * num_beams, vocab_size)\n",
    "    # Top-p/top-k filtering，这一步重建了候选集\n",
    "    _scores = top_k_top_p_filtering(\n",
    "        _scores, top_k=top_k, top_p=top_p,\n",
    "        min_tokens_to_keep=2)  # (batch_size * num_beams, vocab_size)\n",
    "    # re-organize to group the beam together to sample from all beam_idxs\n",
    "    _scores = _scores.contiguous().view(\n",
    "        batch_size,\n",
    "        num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)\n",
    "\n",
    "    # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)\n",
    "    probs = F.softmax(_scores, dim=-1)\n",
    "    # 采样\n",
    "    next_tokens = torch.multinomial(probs, num_samples=2 *\n",
    "                                    num_beams)  # (batch_size, num_beams * 2)\n",
    "    # Compute next scores\n",
    "    next_scores = torch.gather(_scores, -1,\n",
    "                               next_tokens)  # (batch_size, num_beams * 2)\n",
    "    # sort the sampled vector to make sure that the first num_beams samples are the best\n",
    "    next_scores, next_scores_indices = torch.sort(next_scores,\n",
    "                                                  descending=True,\n",
    "                                                  dim=1)\n",
    "    next_tokens = torch.gather(\n",
    "        next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)\n",
    "else:\n",
    "    # 这是昨天的beam search方式\n",
    "    # 直接将log概率相加求条件概率\n",
    "    next_scores = scores + beam_scores[:, None].expand_as(\n",
    "        scores)  # (batch_size * num_beams, vocab_size)\n",
    "\n",
    "    # re-organize to group the beam together (we are keeping top hypothesis accross beams)\n",
    "    next_scores = next_scores.view(\n",
    "        batch_size,\n",
    "        num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)\n",
    "\n",
    "    next_scores, next_tokens = torch.topk(next_scores,\n",
    "                                          2 * num_beams,\n",
    "                                          dim=1,\n",
    "                                          largest=True,\n",
    "                                          sorted=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
